{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import time\n",
    "import yaml\n",
    "import numpy as np\n",
    "import torch\n",
    "import mmengine\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# The project imports below are resolved relative to the parent directory of this notebook.\n",
    "# Only step up when the 'new_model' package is not already visible from the working\n",
    "# directory, so that re-running this cell in a live kernel does not climb one more level\n",
    "# each time (which would break every relative path used afterwards).\n",
    "if not os.path.isdir('new_model'):\n",
    "    os.chdir('../')\n",
    "\n",
    "from new_model.model import TrajectoryModel\n",
    "from new_tools.dataset import TrajectoryDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TrajectoryModel(\n",
       "  (encoder): Encoder(\n",
       "    (self_embedding): ModuleList(\n",
       "      (0-2): 3 x Linear(in_features=80, out_features=128, bias=True)\n",
       "    )\n",
       "    (nei_embedding): ModuleList(\n",
       "      (0-3): 4 x Linear(in_features=32, out_features=128, bias=True)\n",
       "    )\n",
       "  )\n",
       "  (decoder): Decoder(\n",
       "    (decoderLayers): ModuleList(\n",
       "      (0-1): 2 x DecoderLayer(\n",
       "        (attns): ModuleList(\n",
       "          (0-1): 2 x TransformerBlock(\n",
       "            (attn): MultihHeadAttention(\n",
       "              (w_key): Linear(in_features=128, out_features=128, bias=True)\n",
       "              (w_query): Linear(in_features=128, out_features=128, bias=True)\n",
       "              (w_value): Linear(in_features=128, out_features=128, bias=True)\n",
       "              (fc_out): Linear(in_features=128, out_features=128, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (norm1): LayerNorm()\n",
       "            (norm2): LayerNorm()\n",
       "            (feed_forward): FeedForwardLayer(\n",
       "              (w1): Linear(in_features=128, out_features=256, bias=True)\n",
       "              (w2): Linear(in_features=256, out_features=128, bias=True)\n",
       "            )\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pred_head): ModuleList(\n",
       "      (0-2): 3 x Linear(in_features=128, out_features=48, bias=True)\n",
       "    )\n",
       "    (socre_head): ModuleList(\n",
       "      (0-2): 3 x Linear(in_features=128, out_features=1, bias=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dataset location and the name passed to TrajectoryDataset.\n",
    "dataset_path = './data/CODA/'\n",
    "dataset_name = 'coda'\n",
    "\n",
    "# Read the experiment configuration. The context manager closes the file handle as soon as\n",
    "# parsing is done instead of leaving it open until garbage collection.\n",
    "# NOTE(review): FullLoader is kept for compatibility; if the config only holds plain\n",
    "# scalars, lists and mappings, yaml.safe_load is the stricter choice.\n",
    "with open('./configs/coda.yaml', 'r') as config_file:\n",
    "    config = yaml.load(config_file, Loader=yaml.FullLoader)\n",
    "\n",
    "# Train / validation splits and their loaders; each dataset supplies its own collate_fn.\n",
    "train_dataset = TrajectoryDataset(dataset_path, dataset_name, 'train', class_balance=True)\n",
    "val_dataset = TrajectoryDataset(dataset_path, dataset_name, 'val')\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=4, collate_fn=train_dataset.collate_fn)\n",
    "val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=4, collate_fn=val_dataset.collate_fn)\n",
    "\n",
    "# Build the model from the config values and move it to the GPU; the trailing expression\n",
    "# displays the module tree.\n",
    "model = TrajectoryModel(num_class=config['num_class'], in_size=2, obs_len=config['obs_len'], pred_len=config['pred_len'],\n",
    "                        embed_size=config['embed_size'], num_decode_layers=config['num_decode_layers'], scale=0.1)\n",
    "model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Push a single validation batch through the model in eval mode with gradients disabled.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for data_batch in val_dataloader:\n",
    "        # Move every element of the batch to the GPU before unpacking it.\n",
    "        data_batch = [item.cuda() for item in data_batch]\n",
    "        (obs, futures, neis, nei_masks,\n",
    "         self_labels, nei_labels, refs, rot_mats) = data_batch\n",
    "        preds, scores, init_traj = model(obs, neis, nei_masks, self_labels, nei_labels)\n",
    "        break  # only the first batch is needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([128, 49, 24, 2]),\n",
       " torch.Size([128, 49]),\n",
       " torch.Size([3, 49, 24, 2]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shapes of the three values returned by the model, in return order.\n",
    "tuple(output.shape for output in (preds, scores, init_traj))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run one training-mode forward pass on the first batch of the training loader.\n",
    "model.train()\n",
    "for data_batch in train_dataloader:\n",
    "    # Move every element of the batch to the GPU before unpacking it.\n",
    "    data_batch = [item.cuda() for item in data_batch]\n",
    "    (obs, futures, neis, nei_masks,\n",
    "     self_labels, nei_labels, refs, rot_mats) = data_batch\n",
    "    preds, scores, init_traj = model(obs, neis, nei_masks, self_labels, nei_labels)\n",
    "    break  # stop after the first batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the current batch into its eight named components.\n",
    "(obs, futures, neis, nei_masks,\n",
    " self_labels, nei_labels, refs, rot_mats) = data_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of every component of the batch, in unpacking order.\n",
    "batch_parts = (obs, futures, neis, nei_masks, self_labels, nei_labels, refs, rot_mats)\n",
    "tuple(part.shape for part in batch_parts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Length of the second-to-last dimension of the predictions.\n",
    "preds.shape[-2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the training split with class balancing switched on.\n",
    "train_dataset = TrajectoryDataset(\n",
    "    dataset_path,\n",
    "    dataset_name,\n",
    "    'train',\n",
    "    class_balance=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples in the training split.\n",
    "len(train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tally the training samples per label into a three-slot list indexed by the label value.\n",
    "num = [0, 0, 0]\n",
    "for sample in train_dataset.data_list:\n",
    "    label = sample['label']\n",
    "    num[label] += 1\n",
    "num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ros_perception",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
